{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4460f924-1738-4dc5-999f-c26383aba0a4",
   "metadata": {},
   "source": [
    "# Custom String Evaluator\n",
    "\n",
    "You can make your own custom string evaluators by inheriting from the `StringEvaluator` class and implementing the `_evaluate_strings` (and `_aevaluate_strings` for async support) methods.\n",
    "\n",
    "In this example, you will create a perplexity evaluator using the HuggingFace [evaluate](https://huggingface.co/docs/evaluate/index) library.\n",
    "[Perplexity](https://en.wikipedia.org/wiki/Perplexity) is a measure of how well the generated text would be predicted by the model used to compute the metric; a lower perplexity means the model finds the text more predictable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90ec5942-4b14-47b1-baff-9dd2a9f17a4e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %pip install evaluate > /dev/null"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "54fdba68-0ae7-4102-a45b-dabab86c97ac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Any, Optional\n",
    "\n",
    "from langchain.evaluation import StringEvaluator\n",
    "from evaluate import load\n",
    "\n",
    "\n",
    "class PerplexityEvaluator(StringEvaluator):\n",
    "    \"\"\"String evaluator that scores a prediction by its perplexity.\n",
    "\n",
    "    The score comes from the `perplexity` metric obtained through\n",
    "    `evaluate.load`, computed with the model named by `model_id`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_id: str = \"gpt2\"):\n",
    "        \"\"\"Remember `model_id` and load the perplexity metric once for reuse.\"\"\"\n",
    "        self.model_id = model_id\n",
    "        metric_options = {\n",
    "            \"module_type\": \"metric\",\n",
    "            \"model_id\": model_id,\n",
    "            \"pad_token\": 0,\n",
    "        }\n",
    "        self.metric_fn = load(\"perplexity\", **metric_options)\n",
    "\n",
    "    def _evaluate_strings(\n",
    "        self,\n",
    "        *,\n",
    "        prediction: str,\n",
    "        reference: Optional[str] = None,\n",
    "        input: Optional[str] = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> dict:\n",
    "        \"\"\"Return a dict whose `score` entry is the perplexity of `prediction`.\n",
    "\n",
    "        Args:\n",
    "            prediction: The text to score.\n",
    "            reference: Accepted for interface compatibility; not used here.\n",
    "            input: Accepted for interface compatibility; not used here.\n",
    "            **kwargs: Accepted and ignored.\n",
    "        \"\"\"\n",
    "        perplexities = self.metric_fn.compute(\n",
    "            predictions=[prediction], model_id=self.model_id\n",
    "        )[\"perplexities\"]\n",
    "        # Only one prediction is submitted, so its perplexity is the first entry.\n",
    "        return {\"score\": perplexities[0]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "52767568-8075-4f77-93c9-80e1a7e5cba3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "evaluator = PerplexityEvaluator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "697ee0c0-d1ae-4a55-a542-a0f8e602c28a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using pad_token, but it is not set yet.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "467109d44654486e8b415288a319fc2c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'score': 190.3675537109375}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluator.evaluate_strings(prediction=\"The rains in Spain fall mainly on the plain.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5089d9d1-eae6-4d47-b4f6-479e5d887d74",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using pad_token, but it is not set yet.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3266f6f06d746e1bb03ce4aca07d9b9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'score': 1982.0709228515625}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The perplexity is much higher because LangChain was introduced after GPT-2 was released and because the word is never used in a context like this one.\n",
    "evaluator.evaluate_strings(prediction=\"The rains in Spain fall mainly on LangChain.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eaa178f-6ba3-47ae-b3dc-1b196af6d213",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
